package utils;

import com.alibaba.druid.pool.DruidDataSource;
import com.fasterxml.jackson.databind.ObjectMapper;
import com.fasterxml.jackson.databind.node.ObjectNode;
import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;

import java.sql.*;
import java.time.LocalDateTime;
import java.time.format.DateTimeFormatter;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import java.util.concurrent.TimeUnit;
import java.util.regex.Matcher;
import java.util.regex.Pattern;

/**
 * Static JDBC helper built on a lazily created Druid connection pool.
 *
 * <p>Each calling thread is handed at most one pooled connection at a time, cached in a
 * {@link ThreadLocal}. A connection that has not been requested through
 * {@link #getConnection()} for {@code ConnectionWrapper.TIMEOUT} milliseconds is returned to
 * the pool by a background reaper task. Callers that hold a connection (or a
 * {@link ResultSet} obtained from {@link #execute(String, Object...)}) for longer than that
 * must call {@link #getConnection()} again to keep it alive.
 *
 * <p>Call {@link #shutdown()} once when the application stops.
 */
public class JDBCUtils {
    private static final Logger logger = LogManager.getLogger(JDBCUtils.class);
    private static final ObjectMapper objectMapper = new ObjectMapper();
    private static final DateTimeFormatter dateFormatter = DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");

    /** Matches every JDBC placeholder in a statement. */
    private static final Pattern paramPattern = Pattern.compile("\\?");

    /**
     * Matches {@code <keyword> <column> = ?} where the {@code ?} is the LAST character of the
     * examined text, so a column name is only ever attributed to the placeholder it belongs to.
     */
    private static final Pattern namedParamPattern = Pattern.compile(
            "(WHERE|AND|OR|SET|INSERT|UPDATE|DELETE)\\s+([\\w_]+)\\s*=\\s*\\?$",
            Pattern.CASE_INSENSITIVE);

    /** How many characters before a placeholder are inspected when looking for its column name. */
    private static final int PARAM_CONTEXT_CHARS = 64;

    // NOTE(review): connection settings and credentials are hard-coded; consider externalising them.
    private static final String JDBC_URL = "jdbc:mysql://localhost:3306/temp-gui2";
    private static final String JDBC_USER = "root";
    private static final String JDBC_PASSWORD = "root";

    /** volatile: required for the double-checked lazy initialisation in initializeDataSource(). */
    private static volatile DruidDataSource dataSource = null;
    private static final ThreadLocal<ConnectionWrapper> connectionThreadLocal = new ThreadLocal<>();

    /** Single daemon thread running the idle-connection reaper; daemon so it never blocks JVM exit. */
    private static final ScheduledExecutorService scheduler = Executors.newScheduledThreadPool(1, task -> {
        Thread reaper = new Thread(task, "jdbc-utils-connection-reaper");
        reaper.setDaemon(true);
        return reaper;
    });

    /**
     * Creates and configures the Druid pool on first use.
     *
     * <p>The pool is fully configured before it is published to {@link #dataSource}, so a
     * failed initialisation never leaves a half-configured pool visible to other callers.
     *
     * @throws RuntimeException if the Druid filter configuration is rejected
     */
    private static void initializeDataSource() {
        if (dataSource != null) {
            return;
        }
        synchronized (JDBCUtils.class) {
            if (dataSource != null) {
                return;
            }
            DruidDataSource pool = new DruidDataSource();
            pool.setUrl(JDBC_URL);
            pool.setUsername(JDBC_USER);
            pool.setPassword(JDBC_PASSWORD);
            pool.setInitialSize(5);
            pool.setMinIdle(5);
            pool.setMaxActive(20);
            try {
                pool.setFilters("stat");
            } catch (SQLException e) {
                logger.error("Initialize Druid connection pool failed at {}",
                        LocalDateTime.now().format(dateFormatter), e);
                throw new RuntimeException(e);
            }
            dataSource = pool;
            logger.info("Database connection pool initialized at {}",
                    LocalDateTime.now().format(dateFormatter));
        }
    }

    /**
     * Returns the calling thread's connection, borrowing a new one from the pool when the
     * thread has none or its previous one has expired or been closed. Every call resets the
     * idle timer of the returned connection.
     *
     * @return an open connection bound to the calling thread
     * @throws RuntimeException if the pool cannot supply a connection
     */
    public static Connection getConnection() {
        ConnectionWrapper wrapper = connectionThreadLocal.get();
        // tryRenew() is atomic with the reaper's closeIfExpired(), so a connection that is
        // handed out here cannot be closed by the reaper in the same instant.
        if (wrapper != null && wrapper.tryRenew()) {
            return wrapper.getConnection();
        }
        if (wrapper != null) {
            // Stale wrapper: give its connection back to the pool instead of leaking it.
            connectionThreadLocal.remove();
            try {
                wrapper.close();
            } catch (SQLException e) {
                logger.error("Close database connection failed at {}",
                        LocalDateTime.now().format(dateFormatter), e);
            }
        }

        initializeDataSource();
        try {
            Connection conn = dataSource.getConnection();
            ConnectionWrapper fresh = new ConnectionWrapper(conn);
            connectionThreadLocal.set(fresh);
            scheduleConnectionRelease(fresh);
            logger.debug("New database connection created at {}",
                    LocalDateTime.now().format(dateFormatter));
            return conn;
        } catch (SQLException e) {
            logger.error("Get database connection failed at {}",
                    LocalDateTime.now().format(dateFormatter), e);
            throw new RuntimeException("Get database connection failed", e);
        }
    }

    /**
     * Schedules an expiry check for {@code wrapper} just after its current idle deadline.
     * The check re-schedules itself for as long as the connection keeps being renewed.
     */
    private static void scheduleConnectionRelease(ConnectionWrapper wrapper) {
        // +1 ms so the check lands strictly after the deadline (isExpired uses '>').
        long delayMs = Math.max(1L, wrapper.remainingMillis() + 1L);
        scheduler.schedule(() -> releaseIfExpired(wrapper), delayMs, TimeUnit.MILLISECONDS);
    }

    /**
     * Reaper task body. Operates on the wrapper it was given rather than on the thread-local,
     * because it runs on the scheduler thread, whose thread-local slot is always empty.
     */
    private static void releaseIfExpired(ConnectionWrapper wrapper) {
        if (wrapper.isClosed()) {
            return; // already closed by its owner; nothing left to watch
        }
        try {
            if (wrapper.closeIfExpired()) {
                logger.debug("Expired connection closed at {}",
                        LocalDateTime.now().format(dateFormatter));
                return;
            }
        } catch (SQLException e) {
            logger.error("Close database connection failed at {}",
                    LocalDateTime.now().format(dateFormatter), e);
            return;
        }
        try {
            scheduleConnectionRelease(wrapper);
        } catch (java.util.concurrent.RejectedExecutionException e) {
            // Scheduler was shut down; the pool itself is closed by shutdown().
            logger.debug("Connection reaper stopped at {}",
                    LocalDateTime.now().format(dateFormatter));
        }
    }

    /**
     * Returns the calling thread's connection (if any) to the pool and clears the thread-local.
     * Failures while closing are logged, not thrown.
     */
    public static void closeConnection() {
        ConnectionWrapper wrapper = connectionThreadLocal.get();
        if (wrapper == null) {
            return;
        }
        try {
            wrapper.close();
            logger.debug("Database connection closed at {}",
                    LocalDateTime.now().format(dateFormatter));
        } catch (SQLException e) {
            logger.error("Close database connection failed at {}",
                    LocalDateTime.now().format(dateFormatter), e);
        } finally {
            connectionThreadLocal.remove();
        }
    }

    /**
     * Executes a parameterised statement on the calling thread's connection and logs a JSON
     * record describing the execution.
     *
     * <p>Statements whose text starts with {@code select} (case-insensitive, ignoring leading
     * whitespace) are run as queries; everything else is run as an update.
     *
     * @param sql    statement text with {@code ?} placeholders
     * @param params values bound to the placeholders in order; {@code null} is treated as none
     * @return for queries, the open {@link ResultSet} — the caller owns it and must close it
     *         together with {@code rs.getStatement()}; for updates, the affected row count as
     *         an {@link Integer}
     * @throws SQLException if preparing, binding or executing the statement fails
     */
    public static Object execute(String sql, Object... params) throws SQLException {
        Object[] args = params != null ? params : new Object[0];
        LocalDateTime startTime = LocalDateTime.now();

        Connection connection = getConnection();

        ObjectNode logEntry = objectMapper.createObjectNode();
        logEntry.put("prePareStatementSql", sql);
        logEntry.put("startTime", startTime.format(dateFormatter));
        logEntry.set("params", objectMapper.valueToTree(createParamMap(sql, args)));

        // Locale-independent, allocation-free replacement for toLowerCase().startsWith("select").
        boolean isQuery = sql.trim().regionMatches(true, 0, "select", 0, "select".length());
        String actualSql = replaceSqlParams(sql, args);

        PreparedStatement stmt = connection.prepareStatement(sql);
        boolean handedToCaller = false;
        try {
            for (int i = 0; i < args.length; i++) {
                stmt.setObject(i + 1, args[i]);
            }

            if (isQuery) {
                ResultSet rs = stmt.executeQuery();
                logExecution(logEntry, "query", null, actualSql, startTime);
                // The statement must outlive this method: closing it would close the ResultSet.
                handedToCaller = true;
                return rs;
            }

            int affected = stmt.executeUpdate();
            logExecution(logEntry, "update", affected, actualSql, startTime);
            return affected;
        } finally {
            if (!handedToCaller) {
                try {
                    stmt.close();
                } catch (SQLException e) {
                    // Best effort: a close failure must not mask the real outcome.
                    logger.warn("Close statement failed at {}",
                            LocalDateTime.now().format(dateFormatter), e);
                }
            }
        }
    }

    /**
     * Completes and emits the JSON execution record.
     *
     * @param affectedRows row count for updates, or {@code null} for queries (field omitted)
     */
    private static void logExecution(ObjectNode logEntry, String type, Integer affectedRows,
                                     String actualSql, LocalDateTime startTime) {
        LocalDateTime endTime = LocalDateTime.now();
        logEntry.put("type", type);
        if (affectedRows != null) {
            logEntry.put("affectedRows", affectedRows.intValue());
        }
        logEntry.put("actualSql", actualSql);
        logEntry.put("endTime", endTime.format(dateFormatter));
        logEntry.put("executionTimeMs", java.time.Duration.between(startTime, endTime).toMillis());
        logger.info("SQL Execution Log: {}", logEntry.toString());
    }

    /**
     * Builds an ordered, log-friendly map of the bound parameters. Keys are
     * {@code "<column> (paramN)"} when the column name can be recognised directly in front
     * of the placeholder, otherwise {@code "paramN"}. {@code null} values are stored as the
     * string {@code "null"}. Parameters beyond the number of placeholders are not included.
     */
    private static Map<String, Object> createParamMap(String sql, Object[] params) {
        Map<String, Object> paramMap = new LinkedHashMap<>();
        Matcher matcher = paramPattern.matcher(sql);
        int paramIndex = 0;

        while (paramIndex < params.length && matcher.find()) {
            // Only text up to and including this '?' is examined, so a neighbouring
            // "col = ?" can never be mistaken for this placeholder's column.
            int start = Math.max(0, matcher.start() - PARAM_CONTEXT_CHARS);
            String context = sql.substring(start, matcher.end());

            String paramName = extractParamName(context);
            String key = paramName != null
                    ? String.format("%s (param%d)", paramName, paramIndex + 1)
                    : String.format("param%d", paramIndex + 1);

            Object value = params[paramIndex];
            paramMap.put(key, value != null ? value : "null");
            paramIndex++;
        }
        return paramMap;
    }

    /**
     * Extracts the column name from text ending in {@code <keyword> <column> = ?}.
     *
     * @param context SQL fragment whose last character is the placeholder of interest
     * @return the column name, or {@code null} when the fragment does not end in that shape
     */
    private static String extractParamName(String context) {
        Matcher matcher = namedParamPattern.matcher(context);
        return matcher.find() ? matcher.group(2) : null;
    }

    /**
     * Renders the statement with its parameters substituted, for logging only — the result
     * is never executed. Placeholders are replaced left to right; a {@code ?} that appears
     * inside a SQL string literal is not distinguished from a real placeholder.
     */
    private static String replaceSqlParams(String sql, Object[] params) {
        StringBuilder result = new StringBuilder(sql);
        int offset = 0;

        for (Object param : params) {
            int questionMarkIndex = result.indexOf("?", offset);
            if (questionMarkIndex == -1) {
                break;
            }
            String literal = toSqlLiteral(param);
            result.replace(questionMarkIndex, questionMarkIndex + 1, literal);
            // Continue after the inserted text so a '?' inside a value is not substituted.
            offset = questionMarkIndex + literal.length();
        }
        return result.toString();
    }

    /**
     * Formats one parameter as it would appear in SQL text: {@code NULL}, a single-quoted
     * literal (embedded quotes doubled) for textual and date/time values, or
     * {@code toString()} for everything else.
     */
    private static String toSqlLiteral(Object value) {
        if (value == null) {
            return "NULL";
        }
        // java.util.Date also covers java.sql.Date, Time and Timestamp.
        boolean quoted = value instanceof CharSequence
                || value instanceof Character
                || value instanceof java.util.Date
                || value instanceof java.time.temporal.TemporalAccessor;
        return quoted ? "'" + value.toString().replace("'", "''") + "'" : value.toString();
    }

    /**
     * Stops the reaper (waiting up to five seconds for it to finish) and closes the pool if
     * it was ever created. Intended to be called once at application shutdown.
     */
    public static void shutdown() {
        logger.info("Initiating JDBCUtils shutdown at {}",
                LocalDateTime.now().format(dateFormatter));

        scheduler.shutdownNow();
        try {
            if (!scheduler.awaitTermination(5, TimeUnit.SECONDS)) {
                logger.warn("Scheduler did not terminate within 5 seconds at {}",
                        LocalDateTime.now().format(dateFormatter));
            }
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            logger.error("Shutdown interrupted at {}",
                    LocalDateTime.now().format(dateFormatter), e);
        }

        DruidDataSource pool = dataSource;
        if (pool != null) {
            pool.close();
            logger.info("Database connection pool closed at {}",
                    LocalDateTime.now().format(dateFormatter));
        }
    }

    /**
     * Pairs a pooled connection with its last-access time. Instances are shared between the
     * owning thread and the reaper thread, so all mutable state is guarded by the instance
     * monitor.
     */
    private static class ConnectionWrapper {
        private static final long TIMEOUT = 5000; // idle timeout in milliseconds

        private final Connection connection;
        private long lastAccessTime; // guarded by this
        private boolean closed;      // guarded by this

        public ConnectionWrapper(Connection connection) {
            this.connection = connection;
            this.lastAccessTime = System.currentTimeMillis();
        }

        public Connection getConnection() {
            return connection;
        }

        /** Resets the idle timer unconditionally. */
        public synchronized void renew() {
            this.lastAccessTime = System.currentTimeMillis();
        }

        /** @return {@code true} once more than {@code TIMEOUT} ms have passed since the last access */
        public synchronized boolean isExpired() {
            return System.currentTimeMillis() - lastAccessTime > TIMEOUT;
        }

        /** @return {@code true} once this wrapper's connection has been closed through it */
        public synchronized boolean isClosed() {
            return closed;
        }

        /** @return milliseconds until the idle deadline; zero or negative when already due */
        public synchronized long remainingMillis() {
            return TIMEOUT - (System.currentTimeMillis() - lastAccessTime);
        }

        /**
         * Atomically checks that the connection is still usable and, if so, resets its idle
         * timer.
         *
         * @return {@code false} when the connection is closed or expired and must be replaced
         */
        public synchronized boolean tryRenew() {
            if (closed || isExpired()) {
                return false;
            }
            lastAccessTime = System.currentTimeMillis();
            return true;
        }

        /**
         * Closes the connection if, and only if, it is expired and not yet closed.
         *
         * @return {@code true} when this call closed the connection
         * @throws SQLException if returning the connection to the pool fails
         */
        public boolean closeIfExpired() throws SQLException {
            synchronized (this) {
                if (closed || !isExpired()) {
                    return false;
                }
                closed = true;
            }
            // Closed outside the monitor so pool code never runs under our lock.
            connection.close();
            return true;
        }

        /**
         * Closes the connection unless it has already been closed through this wrapper.
         *
         * @throws SQLException if returning the connection to the pool fails
         */
        public void close() throws SQLException {
            synchronized (this) {
                if (closed) {
                    return;
                }
                closed = true;
            }
            connection.close();
        }
    }
}
